# Before running, install required packages:
{% if notebook %}

!
{%- else %}
#
{%- endif %}
 pip install numpy scikit-learn{% if data_format == "Image files" %} torchvision{% endif %}{% if visualization_tool == "Tensorboard" %} tensorboardX{% endif %}{% if visualization_tool == "comet.ml" %} comet_ml{% endif %}
{% if notebook %}


# ---
{% endif %}


{% if visualization_tool == "comet.ml" %}
from comet_ml import Experiment  # has to be 1st import
{% endif %}
import numpy as np
import sklearn
from {{ ".".join(model_func.split(".")[:-1]) }} import {{ model_func.split(".")[-1] }}
{% if data_format == "Image files" %}
from torchvision import datasets, transforms
import urllib
import zipfile
{% endif %}
{% if visualization_tool == "Tensorboard" %}
from tensorboardX import SummaryWriter
from datetime import datetime
{% endif %}

{% if data_format == "Numpy arrays" %}
def fake_data():
    """Create a tiny random example dataset in the [images, labels] format.

    Returns:
        list: Two numpy arrays: 4 random images of shape 1x16x16 and the
        matching labels 0, 1, 2, 3.
    """
    num_samples = 4
    images = np.random.rand(num_samples, 1, 16, 16)
    labels = np.arange(num_samples)
    return [images, labels]

{% elif data_format == "Image files" %}
# COMMENT THIS OUT IF YOU USE YOUR OWN DATA.
# Download example data into ./data/image-data (4 image files, 2 for "dog", 2 for "cat").
# Import the submodule explicitly: `import urllib` by itself does not load
# `urllib.request`.
import urllib.request

url = "https://github.com/jrieke/traingenerator/raw/main/data/fake-image-data.zip"
# Without a target filename, urlretrieve stores the download in a temporary
# file and returns its path.
zip_path, _ = urllib.request.urlretrieve(url)
with zipfile.ZipFile(zip_path, "r") as f:
    f.extractall("data")
    
{% endif %}

{{ header("Setup") }}
{% if data_format == "Numpy arrays" %}
# INSERT YOUR DATA HERE
# Expected format: [images, labels]
# - images has array shape (num samples, color channels, height, width)
# - labels has array shape (num samples,)
# val_data and test_data may be None; these splits are then skipped.
train_data = fake_data()  # required
val_data = fake_data()    # optional
test_data = None          # optional
{% elif data_format == "Image files" %}
# INSERT YOUR DATA HERE
# Expected format: One folder per class, e.g.
# train
# +-- dogs
# |   +-- lassie.jpg
# |   +-- komissar-rex.png
# +-- cats
#     +-- garfield.png
#     +-- smelly-cat.png
#
# Example: https://github.com/jrieke/traingenerator/tree/main/data/image-data
# Each variable holds the path to such a folder. val_data and test_data may be
# None; these splits are then skipped.
train_data = "data/image-data"  # required
val_data = "data/image-data"    # optional
test_data = None                # optional
{% endif %}

{% if visualization_tool == "Tensorboard" %}
# Set up logging. Every run gets its own timestamped folder logs/<experiment_id>,
# so results of separate runs do not end up in the same directory.
experiment_id = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
writer = SummaryWriter(logdir=f"logs/{experiment_id}")

{% elif visualization_tool == "comet.ml" %}
# Set up logging: create the comet.ml experiment that the metrics are logged to.
# NOTE: the API key is written into this script as plain text, so be careful
# when sharing the script.
experiment = Experiment("{{ comet_api_key }}"{% if comet_project %}, project_name="{{ comet_project }}"{% endif %})

{% endif %}

{{ header("Preprocessing") }}
{% if scale_mean_std %}
# Set up scaler. It is fitted on the train split inside preprocess() and then
# reused to transform all other splits.
# Import the submodule explicitly: `import sklearn` by itself does not
# guarantee that `sklearn.preprocessing` is loaded.
import sklearn.preprocessing

scaler = sklearn.preprocessing.StandardScaler()

{% endif %}
def preprocess(data, name):
    """Turn one dataset into flattened image vectors and labels.

    Args:
        data: Raw dataset as configured in the setup section, or None for a
            split that is not used.
        name (str): Name of the split. For "train", the train-only steps run
            as well (fitting the scaler if scaling is enabled, shuffling).

    Returns:
        list: [images, labels] with images flattened to shape
        (num samples, num features), or None if data is None.
    """
    if data is None:  # val/test can be empty
        return None
    {% if data_format == "Image files" %}
    # Read image files to pytorch dataset (only temporary).
    transform = transforms.Compose([
        transforms.Resize({{ resize_pixels }}),
        transforms.CenterCrop({{ crop_pixels }}),
        transforms.ToTensor(),
    ])
    data = datasets.ImageFolder(data, transform=transform)

    # Convert to numpy arrays.
    images_shape = (len(data), *data[0][0].shape)
    images = np.zeros(images_shape)
    # Labels are class indices, so store them as integers
    # (np.zeros would default to float64).
    labels = np.zeros(len(data), dtype=int)
    for i, (image, label) in enumerate(data):
        images[i] = image
        labels[i] = label

    {% elif data_format == "Numpy arrays" %}
    images, labels = data

    {% endif %}
    # Flatten each image to a 1D feature vector.
    images = images.reshape(len(images), -1)

    {% if scale_mean_std %}
    # Scale to mean 0 and std 1. The scaler is fitted on the train split only
    # and reused as-is for the other splits.
    if name == "train":
        scaler.fit(images)
    images = scaler.transform(images)

    {% endif %}
    # Shuffle train set (images and labels are shuffled consistently).
    if name == "train":
        images, labels = sklearn.utils.shuffle(images, labels)

    return [images, labels]

# Preprocess all splits. The train split has to come first: if scaling is
# enabled, it fits the scaler that the val/test splits are transformed with.
processed_train_data = preprocess(train_data, "train")
processed_val_data = preprocess(val_data, "val")
processed_test_data = preprocess(test_data, "test")


{{ header("Model") }}
# Create the model with its default parameters.
model = {{ model_func.split(".")[-1] }}()


{{ header("Training") }}
def evaluate(data, name):
    """Report the accuracy of the global model on a single dataset.

    Prints the accuracy and, if a visualization tool is configured, also logs
    it there as "<name>_accuracy". Does nothing if data is None.

    Args:
        data: Preprocessed dataset [images, labels], or None.
        name (str): Name of the split, used as prefix in the output.
    """
    # Optional splits (val/test) may be missing.
    if data is None:
        return

    features, targets = data
    acc = model.score(features, targets)
    print(f"{name + ':':6} accuracy: {acc}")
    {% if visualization_tool == "Tensorboard" %}
    writer.add_scalar(f"{name}_accuracy", acc)
    {% elif visualization_tool == "comet.ml" %}
    experiment.log_metric(f"{name}_accuracy", acc)
    {% endif %}

# Fit the model on the preprocessed train split.
train_images, train_labels = processed_train_data
model.fit(train_images, train_labels)

# Report the accuracy on every split (missing splits are skipped by evaluate).
for split_name, split_data in (
    ("train", processed_train_data),
    ("val", processed_val_data),
    ("test", processed_test_data),
):
    evaluate(split_data, split_name)
